# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A simple web interactive chat demo based on gradio."""
from argparse import ArgumentParser

import gradio as gr
import mdtex2html

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig


# DEFAULT_CKPT_PATH = '/mnt/lth/llm/train/Qwen-main/output_qwen_25_0.5b_instruct/checkpoint-8000'
DEFAULT_CKPT_PATH = '/mnt/lth/llm/train/Qwen-main/output_qwen_25_0.5b_instruct/checkpoint-5000'
# DEFAULT_CKPT_PATH = '/mnt/lth/llm/train/Qwen-main/models/Qwen/Qwen2___5-7B'


def _get_args():
    parser = ArgumentParser()
    parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                        help="Checkpoint name or path; defaults to %(default)r")
    parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")

    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link for the interface.")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server-port", type=int, default=8000,
                        help="Demo server port.")
    parser.add_argument("--server-name", type=str, default="0.0.0.0",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


def _load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.checkpoint_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
        torch_dtype=torch.bfloat16,  # assumes the device supports bf16; use torch_dtype="auto" otherwise
    ).eval()

    config = GenerationConfig.from_pretrained(
        args.checkpoint_path, trust_remote_code=True, resume_download=True,
    )

    return model, tokenizer, config


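# Monkey-patch gr.Chatbot.postprocess so messages are rendered as HTML
# (Markdown/LaTeX converted via mdtex2html) instead of escaped plain text.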
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


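# Convert chat text to HTML: ``` fences open/close <pre><code> blocks, and
# characters inside a fence are escaped so code is displayed verbatim.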
def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = "<br></code></pre>"
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", r"\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


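# Free Python garbage and release cached CUDA memory (used when clearing history).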
def _gc():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _launch_demo(args, model, tokenizer, config):
    SYSTEM_PROMPT = "你是一个小学数学老师。"  # "You are an elementary school math teacher."

    def predict(_query, _chatbot, _task_history):
        print(f"Student: {_parse_text(_query)}")
        _chatbot.append((_parse_text(_query), ""))

        # Concatenate the system prompt and the conversation history.
        # "学生"/"老师" (student/teacher) are the role markers used during fine-tuning.
        full_prompt = f"system\n{SYSTEM_PROMPT}\n"
        for history_query, history_response in _task_history:
            full_prompt += f"学生\n{history_query}\n老师\n{history_response}\n"
        full_prompt += f"学生\n{_query}"

        inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)

        # Generate a response. temperature/top_p only take effect when
        # do_sample=True; max_new_tokens caps the reply length regardless of
        # how long the prompt is (max_length would count prompt tokens too).
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,  # adjust as needed
            do_sample=True,
            temperature=0.7,     # randomness of sampling
            top_p=0.9,           # nucleus-sampling diversity
        )

        # Decode only the newly generated tokens: decoding the full sequence and
        # string-replacing the prompt is fragile, since detokenization may not
        # reproduce the prompt text exactly.
        prompt_len = inputs["input_ids"].shape[1]
        full_response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
        print("Teacher:", full_response)

        # Update the chatbot display and record the turn so multi-turn history
        # and the Regenerate button work.
        _chatbot[-1] = (_parse_text(_query), _parse_text(full_response))
        _task_history.append((_query, full_response))

        yield _chatbot

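    # Note: checkpoints that ship a chat template (e.g. Qwen2.5-Instruct models)
    # could build the prompt with tokenizer.apply_chat_template instead of the
    # hand-rolled format above; a minimal sketch:
    #   messages = [{"role": "system", "content": SYSTEM_PROMPT},
    #               {"role": "user", "content": _query}]
    #   text = tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True)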

    def regenerate(_chatbot, _task_history):
        if not _task_history:
            yield _chatbot
            return
        item = _task_history.pop(-1)
        _chatbot.pop(-1)
        yield from predict(item[0], _chatbot, _task_history)

    def reset_user_input():
        return gr.update(value="")

    def reset_state(_chatbot, _task_history):
        _task_history.clear()
        _chatbot.clear()
        _gc()
        return _chatbot

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>Qwen2.5-0.5B-Chat-SFT Bot</center>""")
        gr.Markdown(
            """\
<center><font size=3>This WebUI is based on Qwen-Chat, developed by Alibaba Cloud. \
(本WebUI基于Qwen-Chat打造，实现聊天机器人功能。)</center>""")

        chatbot = gr.Chatbot(label='Qwen-Chat', elem_classes="control-height")
        query = gr.Textbox(lines=2, label='Input')
        task_history = gr.State([])

        with gr.Row():
            empty_btn = gr.Button("🧹 Clear History (清除历史)")
            submit_btn = gr.Button("🚀 Submit (发送)")
            regen_btn = gr.Button("🤔️ Regenerate (重试)")

        submit_btn.click(predict, [query, chatbot, task_history], [chatbot], show_progress=True)
        submit_btn.click(reset_user_input, [], [query])
        empty_btn.click(reset_state, [chatbot, task_history], outputs=[chatbot], show_progress=True)
        regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)

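    # queue() is required for the generator callbacks (predict/regenerate) to
    # stream their yielded updates to the UI.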
    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = _get_args()

    model, tokenizer, config = _load_model_tokenizer(args)

    _launch_demo(args, model, tokenizer, config)


if __name__ == '__main__':
    main()